{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ab4d0442",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Blank out CUDA_VISIBLE_DEVICES before torch is imported (next cell);\n",
    "# presumably this hides every GPU so the notebook runs on the CPU only.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e9d6261b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    T5ForConditionalGeneration,\n",
    "    AutoConfig,\n",
    ")\n",
    "from pytorch_lightning import LightningModule\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d426b62d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "# Load the tokenizer from the same directory the model config and weights are read from.\n",
    "tokenizer = AutoTokenizer.from_pretrained('./out-base-1.1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "94be4a10",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Module(LightningModule):\n",
    "    \"\"\"Minimal LightningModule that holds a pretrained T5 under ``self.model``.\n",
    "\n",
    "    The network is stored as the ``model`` attribute, so every key returned by\n",
    "    ``state_dict()`` starts with ``model.`` (for example ``model.shared.weight``).\n",
    "\n",
    "    Args:\n",
    "        model_path: Location handed to ``from_pretrained`` for both the config\n",
    "            and the weights. Defaults to './out-base-1.1', the directory this\n",
    "            notebook was written against.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model_path='./out-base-1.1'):\n",
    "        super().__init__()\n",
    "        # Config and weights come from the same location so they cannot drift apart.\n",
    "        config = AutoConfig.from_pretrained(model_path)\n",
    "        self.model = T5ForConditionalGeneration.from_pretrained(\n",
    "            model_path,\n",
    "            config=config,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e4b5d292",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 20G\r\n",
      "-rwxrwxrwx 1 ubuntu ubuntu 2.8G Oct 25 03:43 'model-epoch=00-step=2800.ckpt'\r\n",
      "-rwxrwxrwx 1 ubuntu ubuntu 2.8G Oct 25 03:48 'model-epoch=00-step=3000.ckpt'\r\n",
      "-rwxrwxrwx 1 ubuntu ubuntu 2.8G Oct 25 03:54 'model-epoch=00-step=3200.ckpt'\r\n",
      "-rwxrwxrwx 1 ubuntu ubuntu 2.8G Oct 25 16:38 'model-epoch=01-step=26000.ckpt'\r\n",
      "-rwxrwxrwx 1 ubuntu ubuntu 2.8G Oct 26 00:21 'model-epoch=02-step=44000.ckpt'\r\n",
      "-rwxrwxrwx 1 ubuntu ubuntu 2.8G Oct 26 00:45 'model-epoch=03-step=45000.ckpt'\r\n",
      "-rwxrwxrwx 1 ubuntu ubuntu 2.8G Oct 26 01:09 'model-epoch=03-step=46000.ckpt'\r\n"
     ]
    }
   ],
   "source": [
    "# List the checkpoints saved under logs/base (size and timestamp) to choose one to load.\n",
    "!ls -lh logs/base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "74be228f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the wrapper; its constructor loads the T5 config and weights from './out-base-1.1'.\n",
    "model = Module()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fc74cda8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters of the freshly constructed wrapper, keyed by name (keys carry the `model.` prefix).\n",
    "weights = model.state_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cd9d6272",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the most recent checkpoint onto the CPU and keep a view of the\n",
    "# (parameter name, tensor) pairs stored under its 'state_dict' entry.\n",
    "# NOTE: torch.load unpickles the file, so only load checkpoints you trust.\n",
    "old_weights = torch.load(\n",
    "    'logs/base/model-epoch=03-step=46000.ckpt',\n",
    "    map_location='cpu',\n",
    ")['state_dict'].items()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "652becba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model.shared.weight model.shared.weight\n",
      "model.encoder.embed_tokens.weight model.encoder.embed_tokens.weight\n",
      "model.encoder.block.0.layer.0.SelfAttention.q.weight model.encoder.block.0.layer.0.SelfAttention.q.weight\n",
      "model.encoder.block.0.layer.0.SelfAttention.k.weight model.encoder.block.0.layer.0.SelfAttention.k.weight\n",
      "model.encoder.block.0.layer.0.SelfAttention.v.weight model.encoder.block.0.layer.0.SelfAttention.v.weight\n",
      "model.encoder.block.0.layer.0.SelfAttention.o.weight model.encoder.block.0.layer.0.SelfAttention.o.weight\n",
      "model.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight model.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight\n",
      "model.encoder.block.0.layer.0.layer_norm.weight model.encoder.block.0.layer.0.layer_norm.weight\n",
      "model.encoder.block.0.layer.1.DenseReluDense.wi_0.weight model.encoder.block.0.layer.1.DenseReluDense.wi_0.weight\n",
      "model.encoder.block.0.layer.1.DenseReluDense.wi_1.weight model.encoder.block.0.layer.1.DenseReluDense.wi_1.weight\n",
      "model.encoder.block.0.layer.1.DenseReluDense.wo.weight model.encoder.block.0.layer.1.DenseReluDense.wo.weight\n",
      "model.encoder.block.0.layer.1.layer_norm.weight model.encoder.block.0.layer.1.layer_norm.weight\n",
      "model.encoder.block.1.layer.0.SelfAttention.q.weight model.encoder.block.1.layer.0.SelfAttention.q.weight\n",
      "model.encoder.block.1.layer.0.SelfAttention.k.weight model.encoder.block.1.layer.0.SelfAttention.k.weight\n",
      "model.encoder.block.1.layer.0.SelfAttention.v.weight model.encoder.block.1.layer.0.SelfAttention.v.weight\n",
      "model.encoder.block.1.layer.0.SelfAttention.o.weight model.encoder.block.1.layer.0.SelfAttention.o.weight\n",
      "model.encoder.block.1.layer.0.layer_norm.weight model.encoder.block.1.layer.0.layer_norm.weight\n",
      "model.encoder.block.1.layer.1.DenseReluDense.wi_0.weight model.encoder.block.1.layer.1.DenseReluDense.wi_0.weight\n",
      "model.encoder.block.1.layer.1.DenseReluDense.wi_1.weight model.encoder.block.1.layer.1.DenseReluDense.wi_1.weight\n",
      "model.encoder.block.1.layer.1.DenseReluDense.wo.weight model.encoder.block.1.layer.1.DenseReluDense.wo.weight\n",
      "model.encoder.block.1.layer.1.layer_norm.weight model.encoder.block.1.layer.1.layer_norm.weight\n",
      "model.encoder.block.2.layer.0.SelfAttention.q.weight model.encoder.block.2.layer.0.SelfAttention.q.weight\n",
      "model.encoder.block.2.layer.0.SelfAttention.k.weight model.encoder.block.2.layer.0.SelfAttention.k.weight\n",
      "model.encoder.block.2.layer.0.SelfAttention.v.weight model.encoder.block.2.layer.0.SelfAttention.v.weight\n",
      "model.encoder.block.2.layer.0.SelfAttention.o.weight model.encoder.block.2.layer.0.SelfAttention.o.weight\n",
      "model.encoder.block.2.layer.0.layer_norm.weight model.encoder.block.2.layer.0.layer_norm.weight\n",
      "model.encoder.block.2.layer.1.DenseReluDense.wi_0.weight model.encoder.block.2.layer.1.DenseReluDense.wi_0.weight\n",
      "model.encoder.block.2.layer.1.DenseReluDense.wi_1.weight model.encoder.block.2.layer.1.DenseReluDense.wi_1.weight\n",
      "model.encoder.block.2.layer.1.DenseReluDense.wo.weight model.encoder.block.2.layer.1.DenseReluDense.wo.weight\n",
      "model.encoder.block.2.layer.1.layer_norm.weight model.encoder.block.2.layer.1.layer_norm.weight\n",
      "model.encoder.block.3.layer.0.SelfAttention.q.weight model.encoder.block.3.layer.0.SelfAttention.q.weight\n",
      "model.encoder.block.3.layer.0.SelfAttention.k.weight model.encoder.block.3.layer.0.SelfAttention.k.weight\n",
      "model.encoder.block.3.layer.0.SelfAttention.v.weight model.encoder.block.3.layer.0.SelfAttention.v.weight\n",
      "model.encoder.block.3.layer.0.SelfAttention.o.weight model.encoder.block.3.layer.0.SelfAttention.o.weight\n",
      "model.encoder.block.3.layer.0.layer_norm.weight model.encoder.block.3.layer.0.layer_norm.weight\n",
      "model.encoder.block.3.layer.1.DenseReluDense.wi_0.weight model.encoder.block.3.layer.1.DenseReluDense.wi_0.weight\n",
      "model.encoder.block.3.layer.1.DenseReluDense.wi_1.weight model.encoder.block.3.layer.1.DenseReluDense.wi_1.weight\n",
      "model.encoder.block.3.layer.1.DenseReluDense.wo.weight model.encoder.block.3.layer.1.DenseReluDense.wo.weight\n",
      "model.encoder.block.3.layer.1.layer_norm.weight model.encoder.block.3.layer.1.layer_norm.weight\n",
      "model.encoder.block.4.layer.0.SelfAttention.q.weight model.encoder.block.4.layer.0.SelfAttention.q.weight\n",
      "model.encoder.block.4.layer.0.SelfAttention.k.weight model.encoder.block.4.layer.0.SelfAttention.k.weight\n",
      "model.encoder.block.4.layer.0.SelfAttention.v.weight model.encoder.block.4.layer.0.SelfAttention.v.weight\n",
      "model.encoder.block.4.layer.0.SelfAttention.o.weight model.encoder.block.4.layer.0.SelfAttention.o.weight\n",
      "model.encoder.block.4.layer.0.layer_norm.weight model.encoder.block.4.layer.0.layer_norm.weight\n",
      "model.encoder.block.4.layer.1.DenseReluDense.wi_0.weight model.encoder.block.4.layer.1.DenseReluDense.wi_0.weight\n",
      "model.encoder.block.4.layer.1.DenseReluDense.wi_1.weight model.encoder.block.4.layer.1.DenseReluDense.wi_1.weight\n",
      "model.encoder.block.4.layer.1.DenseReluDense.wo.weight model.encoder.block.4.layer.1.DenseReluDense.wo.weight\n",
      "model.encoder.block.4.layer.1.layer_norm.weight model.encoder.block.4.layer.1.layer_norm.weight\n",
      "model.encoder.block.5.layer.0.SelfAttention.q.weight model.encoder.block.5.layer.0.SelfAttention.q.weight\n",
      "model.encoder.block.5.layer.0.SelfAttention.k.weight model.encoder.block.5.layer.0.SelfAttention.k.weight\n",
      "model.encoder.block.5.layer.0.SelfAttention.v.weight model.encoder.block.5.layer.0.SelfAttention.v.weight\n",
      "model.encoder.block.5.layer.0.SelfAttention.o.weight model.encoder.block.5.layer.0.SelfAttention.o.weight\n",
      "model.encoder.block.5.layer.0.layer_norm.weight model.encoder.block.5.layer.0.layer_norm.weight\n",
      "model.encoder.block.5.layer.1.DenseReluDense.wi_0.weight model.encoder.block.5.layer.1.DenseReluDense.wi_0.weight\n",
      "model.encoder.block.5.layer.1.DenseReluDense.wi_1.weight model.encoder.block.5.layer.1.DenseReluDense.wi_1.weight\n",
      "model.encoder.block.5.layer.1.DenseReluDense.wo.weight model.encoder.block.5.layer.1.DenseReluDense.wo.weight\n",
      "model.encoder.block.5.layer.1.layer_norm.weight model.encoder.block.5.layer.1.layer_norm.weight\n",
      "model.encoder.block.6.layer.0.SelfAttention.q.weight model.encoder.block.6.layer.0.SelfAttention.q.weight\n",
      "model.encoder.block.6.layer.0.SelfAttention.k.weight model.encoder.block.6.layer.0.SelfAttention.k.weight\n",
      "model.encoder.block.6.layer.0.SelfAttention.v.weight model.encoder.block.6.layer.0.SelfAttention.v.weight\n",
      "model.encoder.block.6.layer.0.SelfAttention.o.weight model.encoder.block.6.layer.0.SelfAttention.o.weight\n",
      "model.encoder.block.6.layer.0.layer_norm.weight model.encoder.block.6.layer.0.layer_norm.weight\n",
      "model.encoder.block.6.layer.1.DenseReluDense.wi_0.weight model.encoder.block.6.layer.1.DenseReluDense.wi_0.weight\n",
      "model.encoder.block.6.layer.1.DenseReluDense.wi_1.weight model.encoder.block.6.layer.1.DenseReluDense.wi_1.weight\n",
      "model.encoder.block.6.layer.1.DenseReluDense.wo.weight model.encoder.block.6.layer.1.DenseReluDense.wo.weight\n",
      "model.encoder.block.6.layer.1.layer_norm.weight model.encoder.block.6.layer.1.layer_norm.weight\n",
      "model.encoder.block.7.layer.0.SelfAttention.q.weight model.encoder.block.7.layer.0.SelfAttention.q.weight\n",
      "model.encoder.block.7.layer.0.SelfAttention.k.weight model.encoder.block.7.layer.0.SelfAttention.k.weight\n",
      "model.encoder.block.7.layer.0.SelfAttention.v.weight model.encoder.block.7.layer.0.SelfAttention.v.weight\n",
      "model.encoder.block.7.layer.0.SelfAttention.o.weight model.encoder.block.7.layer.0.SelfAttention.o.weight\n",
      "model.encoder.block.7.layer.0.layer_norm.weight model.encoder.block.7.layer.0.layer_norm.weight\n",
      "model.encoder.block.7.layer.1.DenseReluDense.wi_0.weight model.encoder.block.7.layer.1.DenseReluDense.wi_0.weight\n",
      "model.encoder.block.7.layer.1.DenseReluDense.wi_1.weight model.encoder.block.7.layer.1.DenseReluDense.wi_1.weight\n",
      "model.encoder.block.7.layer.1.DenseReluDense.wo.weight model.encoder.block.7.layer.1.DenseReluDense.wo.weight\n",
      "model.encoder.block.7.layer.1.layer_norm.weight model.encoder.block.7.layer.1.layer_norm.weight\n",
      "model.encoder.block.8.layer.0.SelfAttention.q.weight model.encoder.block.8.layer.0.SelfAttention.q.weight\n",
      "model.encoder.block.8.layer.0.SelfAttention.k.weight model.encoder.block.8.layer.0.SelfAttention.k.weight\n",
      "model.encoder.block.8.layer.0.SelfAttention.v.weight model.encoder.block.8.layer.0.SelfAttention.v.weight\n",
      "model.encoder.block.8.layer.0.SelfAttention.o.weight model.encoder.block.8.layer.0.SelfAttention.o.weight\n",
      "model.encoder.block.8.layer.0.layer_norm.weight model.encoder.block.8.layer.0.layer_norm.weight\n",
      "model.encoder.block.8.layer.1.DenseReluDense.wi_0.weight model.encoder.block.8.layer.1.DenseReluDense.wi_0.weight\n",
      "model.encoder.block.8.layer.1.DenseReluDense.wi_1.weight model.encoder.block.8.layer.1.DenseReluDense.wi_1.weight\n",
      "model.encoder.block.8.layer.1.DenseReluDense.wo.weight model.encoder.block.8.layer.1.DenseReluDense.wo.weight\n",
      "model.encoder.block.8.layer.1.layer_norm.weight model.encoder.block.8.layer.1.layer_norm.weight\n",
      "model.encoder.block.9.layer.0.SelfAttention.q.weight model.encoder.block.9.layer.0.SelfAttention.q.weight\n",
      "model.encoder.block.9.layer.0.SelfAttention.k.weight model.encoder.block.9.layer.0.SelfAttention.k.weight\n",
      "model.encoder.block.9.layer.0.SelfAttention.v.weight model.encoder.block.9.layer.0.SelfAttention.v.weight\n",
      "model.encoder.block.9.layer.0.SelfAttention.o.weight model.encoder.block.9.layer.0.SelfAttention.o.weight\n",
      "model.encoder.block.9.layer.0.layer_norm.weight model.encoder.block.9.layer.0.layer_norm.weight\n",
      "model.encoder.block.9.layer.1.DenseReluDense.wi_0.weight model.encoder.block.9.layer.1.DenseReluDense.wi_0.weight\n",
      "model.encoder.block.9.layer.1.DenseReluDense.wi_1.weight model.encoder.block.9.layer.1.DenseReluDense.wi_1.weight\n",
      "model.encoder.block.9.layer.1.DenseReluDense.wo.weight model.encoder.block.9.layer.1.DenseReluDense.wo.weight\n",
      "model.encoder.block.9.layer.1.layer_norm.weight model.encoder.block.9.layer.1.layer_norm.weight\n",
      "model.encoder.block.10.layer.0.SelfAttention.q.weight model.encoder.block.10.layer.0.SelfAttention.q.weight\n",
      "model.encoder.block.10.layer.0.SelfAttention.k.weight model.encoder.block.10.layer.0.SelfAttention.k.weight\n",
      "model.encoder.block.10.layer.0.SelfAttention.v.weight model.encoder.block.10.layer.0.SelfAttention.v.weight\n",
      "model.encoder.block.10.layer.0.SelfAttention.o.weight model.encoder.block.10.layer.0.SelfAttention.o.weight\n",
      "model.encoder.block.10.layer.0.layer_norm.weight model.encoder.block.10.layer.0.layer_norm.weight\n",
      "model.encoder.block.10.layer.1.DenseReluDense.wi_0.weight model.encoder.block.10.layer.1.DenseReluDense.wi_0.weight\n",
      "model.encoder.block.10.layer.1.DenseReluDense.wi_1.weight model.encoder.block.10.layer.1.DenseReluDense.wi_1.weight\n",
      "model.encoder.block.10.layer.1.DenseReluDense.wo.weight model.encoder.block.10.layer.1.DenseReluDense.wo.weight\n",
      "model.encoder.block.10.layer.1.layer_norm.weight model.encoder.block.10.layer.1.layer_norm.weight\n",
      "model.encoder.block.11.layer.0.SelfAttention.q.weight model.encoder.block.11.layer.0.SelfAttention.q.weight\n",
      "model.encoder.block.11.layer.0.SelfAttention.k.weight model.encoder.block.11.layer.0.SelfAttention.k.weight\n",
      "model.encoder.block.11.layer.0.SelfAttention.v.weight model.encoder.block.11.layer.0.SelfAttention.v.weight\n",
      "model.encoder.block.11.layer.0.SelfAttention.o.weight model.encoder.block.11.layer.0.SelfAttention.o.weight\n",
      "model.encoder.block.11.layer.0.layer_norm.weight model.encoder.block.11.layer.0.layer_norm.weight\n",
      "model.encoder.block.11.layer.1.DenseReluDense.wi_0.weight model.encoder.block.11.layer.1.DenseReluDense.wi_0.weight\n",
      "model.encoder.block.11.layer.1.DenseReluDense.wi_1.weight model.encoder.block.11.layer.1.DenseReluDense.wi_1.weight\n",
      "model.encoder.block.11.layer.1.DenseReluDense.wo.weight model.encoder.block.11.layer.1.DenseReluDense.wo.weight\n",
      "model.encoder.block.11.layer.1.layer_norm.weight model.encoder.block.11.layer.1.layer_norm.weight\n",
      "model.encoder.final_layer_norm.weight model.encoder.final_layer_norm.weight\n",
      "model.decoder.embed_tokens.weight model.decoder.embed_tokens.weight\n",
      "model.decoder.block.0.layer.0.SelfAttention.q.weight model.decoder.block.0.layer.0.SelfAttention.q.weight\n",
      "model.decoder.block.0.layer.0.SelfAttention.k.weight model.decoder.block.0.layer.0.SelfAttention.k.weight\n",
      "model.decoder.block.0.layer.0.SelfAttention.v.weight model.decoder.block.0.layer.0.SelfAttention.v.weight\n",
      "model.decoder.block.0.layer.0.SelfAttention.o.weight model.decoder.block.0.layer.0.SelfAttention.o.weight\n",
      "model.decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight model.decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight\n",
      "model.decoder.block.0.layer.0.layer_norm.weight model.decoder.block.0.layer.0.layer_norm.weight\n",
      "model.decoder.block.0.layer.1.EncDecAttention.q.weight model.decoder.block.0.layer.1.EncDecAttention.q.weight\n",
      "model.decoder.block.0.layer.1.EncDecAttention.k.weight model.decoder.block.0.layer.1.EncDecAttention.k.weight\n",
      "model.decoder.block.0.layer.1.EncDecAttention.v.weight model.decoder.block.0.layer.1.EncDecAttention.v.weight\n",
      "model.decoder.block.0.layer.1.EncDecAttention.o.weight model.decoder.block.0.layer.1.EncDecAttention.o.weight\n",
      "model.decoder.block.0.layer.1.layer_norm.weight model.decoder.block.0.layer.1.layer_norm.weight\n",
      "model.decoder.block.0.layer.2.DenseReluDense.wi_0.weight model.decoder.block.0.layer.2.DenseReluDense.wi_0.weight\n",
      "model.decoder.block.0.layer.2.DenseReluDense.wi_1.weight model.decoder.block.0.layer.2.DenseReluDense.wi_1.weight\n",
      "model.decoder.block.0.layer.2.DenseReluDense.wo.weight model.decoder.block.0.layer.2.DenseReluDense.wo.weight\n",
      "model.decoder.block.0.layer.2.layer_norm.weight model.decoder.block.0.layer.2.layer_norm.weight\n",
      "model.decoder.block.1.layer.0.SelfAttention.q.weight model.decoder.block.1.layer.0.SelfAttention.q.weight\n",
      "model.decoder.block.1.layer.0.SelfAttention.k.weight model.decoder.block.1.layer.0.SelfAttention.k.weight\n",
      "model.decoder.block.1.layer.0.SelfAttention.v.weight model.decoder.block.1.layer.0.SelfAttention.v.weight\n",
      "model.decoder.block.1.layer.0.SelfAttention.o.weight model.decoder.block.1.layer.0.SelfAttention.o.weight\n",
      "model.decoder.block.1.layer.0.layer_norm.weight model.decoder.block.1.layer.0.layer_norm.weight\n",
      "model.decoder.block.1.layer.1.EncDecAttention.q.weight model.decoder.block.1.layer.1.EncDecAttention.q.weight\n",
      "model.decoder.block.1.layer.1.EncDecAttention.k.weight model.decoder.block.1.layer.1.EncDecAttention.k.weight\n",
      "model.decoder.block.1.layer.1.EncDecAttention.v.weight model.decoder.block.1.layer.1.EncDecAttention.v.weight\n",
      "model.decoder.block.1.layer.1.EncDecAttention.o.weight model.decoder.block.1.layer.1.EncDecAttention.o.weight\n",
      "model.decoder.block.1.layer.1.layer_norm.weight model.decoder.block.1.layer.1.layer_norm.weight\n",
      "model.decoder.block.1.layer.2.DenseReluDense.wi_0.weight model.decoder.block.1.layer.2.DenseReluDense.wi_0.weight\n",
      "model.decoder.block.1.layer.2.DenseReluDense.wi_1.weight model.decoder.block.1.layer.2.DenseReluDense.wi_1.weight\n",
      "model.decoder.block.1.layer.2.DenseReluDense.wo.weight model.decoder.block.1.layer.2.DenseReluDense.wo.weight\n",
      "model.decoder.block.1.layer.2.layer_norm.weight model.decoder.block.1.layer.2.layer_norm.weight\n",
      "model.decoder.block.2.layer.0.SelfAttention.q.weight model.decoder.block.2.layer.0.SelfAttention.q.weight\n",
      "model.decoder.block.2.layer.0.SelfAttention.k.weight model.decoder.block.2.layer.0.SelfAttention.k.weight\n",
      "model.decoder.block.2.layer.0.SelfAttention.v.weight model.decoder.block.2.layer.0.SelfAttention.v.weight\n",
      "model.decoder.block.2.layer.0.SelfAttention.o.weight model.decoder.block.2.layer.0.SelfAttention.o.weight\n",
      "model.decoder.block.2.layer.0.layer_norm.weight model.decoder.block.2.layer.0.layer_norm.weight\n",
      "model.decoder.block.2.layer.1.EncDecAttention.q.weight model.decoder.block.2.layer.1.EncDecAttention.q.weight\n",
      "model.decoder.block.2.layer.1.EncDecAttention.k.weight model.decoder.block.2.layer.1.EncDecAttention.k.weight\n",
      "model.decoder.block.2.layer.1.EncDecAttention.v.weight model.decoder.block.2.layer.1.EncDecAttention.v.weight\n",
      "model.decoder.block.2.layer.1.EncDecAttention.o.weight model.decoder.block.2.layer.1.EncDecAttention.o.weight\n",
      "model.decoder.block.2.layer.1.layer_norm.weight model.decoder.block.2.layer.1.layer_norm.weight\n",
      "model.decoder.block.2.layer.2.DenseReluDense.wi_0.weight model.decoder.block.2.layer.2.DenseReluDense.wi_0.weight\n",
      "model.decoder.block.2.layer.2.DenseReluDense.wi_1.weight model.decoder.block.2.layer.2.DenseReluDense.wi_1.weight\n",
      "model.decoder.block.2.layer.2.DenseReluDense.wo.weight model.decoder.block.2.layer.2.DenseReluDense.wo.weight\n",
      "model.decoder.block.2.layer.2.layer_norm.weight model.decoder.block.2.layer.2.layer_norm.weight\n",
      "model.decoder.block.3.layer.0.SelfAttention.q.weight model.decoder.block.3.layer.0.SelfAttention.q.weight\n",
      "model.decoder.block.3.layer.0.SelfAttention.k.weight model.decoder.block.3.layer.0.SelfAttention.k.weight\n",
      "model.decoder.block.3.layer.0.SelfAttention.v.weight model.decoder.block.3.layer.0.SelfAttention.v.weight\n",
      "model.decoder.block.3.layer.0.SelfAttention.o.weight model.decoder.block.3.layer.0.SelfAttention.o.weight\n",
      "model.decoder.block.3.layer.0.layer_norm.weight model.decoder.block.3.layer.0.layer_norm.weight\n",
      "model.decoder.block.3.layer.1.EncDecAttention.q.weight model.decoder.block.3.layer.1.EncDecAttention.q.weight\n",
      "model.decoder.block.3.layer.1.EncDecAttention.k.weight model.decoder.block.3.layer.1.EncDecAttention.k.weight\n",
      "model.decoder.block.3.layer.1.EncDecAttention.v.weight model.decoder.block.3.layer.1.EncDecAttention.v.weight\n",
      "model.decoder.block.3.layer.1.EncDecAttention.o.weight model.decoder.block.3.layer.1.EncDecAttention.o.weight\n",
      "model.decoder.block.3.layer.1.layer_norm.weight model.decoder.block.3.layer.1.layer_norm.weight\n",
      "model.decoder.block.3.layer.2.DenseReluDense.wi_0.weight model.decoder.block.3.layer.2.DenseReluDense.wi_0.weight\n",
      "model.decoder.block.3.layer.2.DenseReluDense.wi_1.weight model.decoder.block.3.layer.2.DenseReluDense.wi_1.weight\n",
      "model.decoder.block.3.layer.2.DenseReluDense.wo.weight model.decoder.block.3.layer.2.DenseReluDense.wo.weight\n",
      "model.decoder.block.3.layer.2.layer_norm.weight model.decoder.block.3.layer.2.layer_norm.weight\n",
      "model.decoder.block.4.layer.0.SelfAttention.q.weight model.decoder.block.4.layer.0.SelfAttention.q.weight\n",
      "model.decoder.block.4.layer.0.SelfAttention.k.weight model.decoder.block.4.layer.0.SelfAttention.k.weight\n",
      "model.decoder.block.4.layer.0.SelfAttention.v.weight model.decoder.block.4.layer.0.SelfAttention.v.weight\n",
      "model.decoder.block.4.layer.0.SelfAttention.o.weight model.decoder.block.4.layer.0.SelfAttention.o.weight\n",
      "model.decoder.block.4.layer.0.layer_norm.weight model.decoder.block.4.layer.0.layer_norm.weight\n",
      "model.decoder.block.4.layer.1.EncDecAttention.q.weight model.decoder.block.4.layer.1.EncDecAttention.q.weight\n",
      "model.decoder.block.4.layer.1.EncDecAttention.k.weight model.decoder.block.4.layer.1.EncDecAttention.k.weight\n",
      "model.decoder.block.4.layer.1.EncDecAttention.v.weight model.decoder.block.4.layer.1.EncDecAttention.v.weight\n",
      "model.decoder.block.4.layer.1.EncDecAttention.o.weight model.decoder.block.4.layer.1.EncDecAttention.o.weight\n",
      "model.decoder.block.4.layer.1.layer_norm.weight model.decoder.block.4.layer.1.layer_norm.weight\n",
      "model.decoder.block.4.layer.2.DenseReluDense.wi_0.weight model.decoder.block.4.layer.2.DenseReluDense.wi_0.weight\n",
      "model.decoder.block.4.layer.2.DenseReluDense.wi_1.weight model.decoder.block.4.layer.2.DenseReluDense.wi_1.weight\n",
      "model.decoder.block.4.layer.2.DenseReluDense.wo.weight model.decoder.block.4.layer.2.DenseReluDense.wo.weight\n",
      "model.decoder.block.4.layer.2.layer_norm.weight model.decoder.block.4.layer.2.layer_norm.weight\n",
      "model.decoder.block.5.layer.0.SelfAttention.q.weight model.decoder.block.5.layer.0.SelfAttention.q.weight\n",
      "model.decoder.block.5.layer.0.SelfAttention.k.weight model.decoder.block.5.layer.0.SelfAttention.k.weight\n",
      "model.decoder.block.5.layer.0.SelfAttention.v.weight model.decoder.block.5.layer.0.SelfAttention.v.weight\n",
      "model.decoder.block.5.layer.0.SelfAttention.o.weight model.decoder.block.5.layer.0.SelfAttention.o.weight\n",
      "model.decoder.block.5.layer.0.layer_norm.weight model.decoder.block.5.layer.0.layer_norm.weight\n",
      "model.decoder.block.5.layer.1.EncDecAttention.q.weight model.decoder.block.5.layer.1.EncDecAttention.q.weight\n",
      "model.decoder.block.5.layer.1.EncDecAttention.k.weight model.decoder.block.5.layer.1.EncDecAttention.k.weight\n",
      "model.decoder.block.5.layer.1.EncDecAttention.v.weight model.decoder.block.5.layer.1.EncDecAttention.v.weight\n",
      "model.decoder.block.5.layer.1.EncDecAttention.o.weight model.decoder.block.5.layer.1.EncDecAttention.o.weight\n",
      "model.decoder.block.5.layer.1.layer_norm.weight model.decoder.block.5.layer.1.layer_norm.weight\n",
      "model.decoder.block.5.layer.2.DenseReluDense.wi_0.weight model.decoder.block.5.layer.2.DenseReluDense.wi_0.weight\n",
      "model.decoder.block.5.layer.2.DenseReluDense.wi_1.weight model.decoder.block.5.layer.2.DenseReluDense.wi_1.weight\n",
      "model.decoder.block.5.layer.2.DenseReluDense.wo.weight model.decoder.block.5.layer.2.DenseReluDense.wo.weight\n",
      "model.decoder.block.5.layer.2.layer_norm.weight model.decoder.block.5.layer.2.layer_norm.weight\n",
      "model.decoder.block.6.layer.0.SelfAttention.q.weight model.decoder.block.6.layer.0.SelfAttention.q.weight\n",
      "model.decoder.block.6.layer.0.SelfAttention.k.weight model.decoder.block.6.layer.0.SelfAttention.k.weight\n",
      "model.decoder.block.6.layer.0.SelfAttention.v.weight model.decoder.block.6.layer.0.SelfAttention.v.weight\n",
      "model.decoder.block.6.layer.0.SelfAttention.o.weight model.decoder.block.6.layer.0.SelfAttention.o.weight\n",
      "model.decoder.block.6.layer.0.layer_norm.weight model.decoder.block.6.layer.0.layer_norm.weight\n",
      "model.decoder.block.6.layer.1.EncDecAttention.q.weight model.decoder.block.6.layer.1.EncDecAttention.q.weight\n",
      "model.decoder.block.6.layer.1.EncDecAttention.k.weight model.decoder.block.6.layer.1.EncDecAttention.k.weight\n",
      "model.decoder.block.6.layer.1.EncDecAttention.v.weight model.decoder.block.6.layer.1.EncDecAttention.v.weight\n",
      "model.decoder.block.6.layer.1.EncDecAttention.o.weight model.decoder.block.6.layer.1.EncDecAttention.o.weight\n",
      "model.decoder.block.6.layer.1.layer_norm.weight model.decoder.block.6.layer.1.layer_norm.weight\n",
      "model.decoder.block.6.layer.2.DenseReluDense.wi_0.weight model.decoder.block.6.layer.2.DenseReluDense.wi_0.weight\n",
      "model.decoder.block.6.layer.2.DenseReluDense.wi_1.weight model.decoder.block.6.layer.2.DenseReluDense.wi_1.weight\n",
      "model.decoder.block.6.layer.2.DenseReluDense.wo.weight model.decoder.block.6.layer.2.DenseReluDense.wo.weight\n",
      "model.decoder.block.6.layer.2.layer_norm.weight model.decoder.block.6.layer.2.layer_norm.weight\n",
      "model.decoder.block.7.layer.0.SelfAttention.q.weight model.decoder.block.7.layer.0.SelfAttention.q.weight\n",
      "model.decoder.block.7.layer.0.SelfAttention.k.weight model.decoder.block.7.layer.0.SelfAttention.k.weight\n",
      "model.decoder.block.7.layer.0.SelfAttention.v.weight model.decoder.block.7.layer.0.SelfAttention.v.weight\n",
      "model.decoder.block.7.layer.0.SelfAttention.o.weight model.decoder.block.7.layer.0.SelfAttention.o.weight\n",
      "model.decoder.block.7.layer.0.layer_norm.weight model.decoder.block.7.layer.0.layer_norm.weight\n",
      "model.decoder.block.7.layer.1.EncDecAttention.q.weight model.decoder.block.7.layer.1.EncDecAttention.q.weight\n",
      "model.decoder.block.7.layer.1.EncDecAttention.k.weight model.decoder.block.7.layer.1.EncDecAttention.k.weight\n",
      "model.decoder.block.7.layer.1.EncDecAttention.v.weight model.decoder.block.7.layer.1.EncDecAttention.v.weight\n",
      "model.decoder.block.7.layer.1.EncDecAttention.o.weight model.decoder.block.7.layer.1.EncDecAttention.o.weight\n",
      "model.decoder.block.7.layer.1.layer_norm.weight model.decoder.block.7.layer.1.layer_norm.weight\n",
      "model.decoder.block.7.layer.2.DenseReluDense.wi_0.weight model.decoder.block.7.layer.2.DenseReluDense.wi_0.weight\n",
      "model.decoder.block.7.layer.2.DenseReluDense.wi_1.weight model.decoder.block.7.layer.2.DenseReluDense.wi_1.weight\n",
      "model.decoder.block.7.layer.2.DenseReluDense.wo.weight model.decoder.block.7.layer.2.DenseReluDense.wo.weight\n",
      "model.decoder.block.7.layer.2.layer_norm.weight model.decoder.block.7.layer.2.layer_norm.weight\n",
      "model.decoder.block.8.layer.0.SelfAttention.q.weight model.decoder.block.8.layer.0.SelfAttention.q.weight\n",
      "model.decoder.block.8.layer.0.SelfAttention.k.weight model.decoder.block.8.layer.0.SelfAttention.k.weight\n",
      "model.decoder.block.8.layer.0.SelfAttention.v.weight model.decoder.block.8.layer.0.SelfAttention.v.weight\n",
      "model.decoder.block.8.layer.0.SelfAttention.o.weight model.decoder.block.8.layer.0.SelfAttention.o.weight\n",
      "model.decoder.block.8.layer.0.layer_norm.weight model.decoder.block.8.layer.0.layer_norm.weight\n",
      "model.decoder.block.8.layer.1.EncDecAttention.q.weight model.decoder.block.8.layer.1.EncDecAttention.q.weight\n",
      "model.decoder.block.8.layer.1.EncDecAttention.k.weight model.decoder.block.8.layer.1.EncDecAttention.k.weight\n",
      "model.decoder.block.8.layer.1.EncDecAttention.v.weight model.decoder.block.8.layer.1.EncDecAttention.v.weight\n",
      "model.decoder.block.8.layer.1.EncDecAttention.o.weight model.decoder.block.8.layer.1.EncDecAttention.o.weight\n",
      "model.decoder.block.8.layer.1.layer_norm.weight model.decoder.block.8.layer.1.layer_norm.weight\n",
      "model.decoder.block.8.layer.2.DenseReluDense.wi_0.weight model.decoder.block.8.layer.2.DenseReluDense.wi_0.weight\n",
      "model.decoder.block.8.layer.2.DenseReluDense.wi_1.weight model.decoder.block.8.layer.2.DenseReluDense.wi_1.weight\n",
      "model.decoder.block.8.layer.2.DenseReluDense.wo.weight model.decoder.block.8.layer.2.DenseReluDense.wo.weight\n",
      "model.decoder.block.8.layer.2.layer_norm.weight model.decoder.block.8.layer.2.layer_norm.weight\n",
      "model.decoder.block.9.layer.0.SelfAttention.q.weight model.decoder.block.9.layer.0.SelfAttention.q.weight\n",
      "model.decoder.block.9.layer.0.SelfAttention.k.weight model.decoder.block.9.layer.0.SelfAttention.k.weight\n",
      "model.decoder.block.9.layer.0.SelfAttention.v.weight model.decoder.block.9.layer.0.SelfAttention.v.weight\n",
      "model.decoder.block.9.layer.0.SelfAttention.o.weight model.decoder.block.9.layer.0.SelfAttention.o.weight\n",
      "model.decoder.block.9.layer.0.layer_norm.weight model.decoder.block.9.layer.0.layer_norm.weight\n",
      "model.decoder.block.9.layer.1.EncDecAttention.q.weight model.decoder.block.9.layer.1.EncDecAttention.q.weight\n",
      "model.decoder.block.9.layer.1.EncDecAttention.k.weight model.decoder.block.9.layer.1.EncDecAttention.k.weight\n",
      "model.decoder.block.9.layer.1.EncDecAttention.v.weight model.decoder.block.9.layer.1.EncDecAttention.v.weight\n",
      "model.decoder.block.9.layer.1.EncDecAttention.o.weight model.decoder.block.9.layer.1.EncDecAttention.o.weight\n",
      "model.decoder.block.9.layer.1.layer_norm.weight model.decoder.block.9.layer.1.layer_norm.weight\n",
      "model.decoder.block.9.layer.2.DenseReluDense.wi_0.weight model.decoder.block.9.layer.2.DenseReluDense.wi_0.weight\n",
      "model.decoder.block.9.layer.2.DenseReluDense.wi_1.weight model.decoder.block.9.layer.2.DenseReluDense.wi_1.weight\n",
      "model.decoder.block.9.layer.2.DenseReluDense.wo.weight model.decoder.block.9.layer.2.DenseReluDense.wo.weight\n",
      "model.decoder.block.9.layer.2.layer_norm.weight model.decoder.block.9.layer.2.layer_norm.weight\n",
      "model.decoder.block.10.layer.0.SelfAttention.q.weight model.decoder.block.10.layer.0.SelfAttention.q.weight\n",
      "model.decoder.block.10.layer.0.SelfAttention.k.weight model.decoder.block.10.layer.0.SelfAttention.k.weight\n",
      "model.decoder.block.10.layer.0.SelfAttention.v.weight model.decoder.block.10.layer.0.SelfAttention.v.weight\n",
      "model.decoder.block.10.layer.0.SelfAttention.o.weight model.decoder.block.10.layer.0.SelfAttention.o.weight\n",
      "model.decoder.block.10.layer.0.layer_norm.weight model.decoder.block.10.layer.0.layer_norm.weight\n",
      "model.decoder.block.10.layer.1.EncDecAttention.q.weight model.decoder.block.10.layer.1.EncDecAttention.q.weight\n",
      "model.decoder.block.10.layer.1.EncDecAttention.k.weight model.decoder.block.10.layer.1.EncDecAttention.k.weight\n",
      "model.decoder.block.10.layer.1.EncDecAttention.v.weight model.decoder.block.10.layer.1.EncDecAttention.v.weight\n",
      "model.decoder.block.10.layer.1.EncDecAttention.o.weight model.decoder.block.10.layer.1.EncDecAttention.o.weight\n",
      "model.decoder.block.10.layer.1.layer_norm.weight model.decoder.block.10.layer.1.layer_norm.weight\n",
      "model.decoder.block.10.layer.2.DenseReluDense.wi_0.weight model.decoder.block.10.layer.2.DenseReluDense.wi_0.weight\n",
      "model.decoder.block.10.layer.2.DenseReluDense.wi_1.weight model.decoder.block.10.layer.2.DenseReluDense.wi_1.weight\n",
      "model.decoder.block.10.layer.2.DenseReluDense.wo.weight model.decoder.block.10.layer.2.DenseReluDense.wo.weight\n",
      "model.decoder.block.10.layer.2.layer_norm.weight model.decoder.block.10.layer.2.layer_norm.weight\n",
      "model.decoder.block.11.layer.0.SelfAttention.q.weight model.decoder.block.11.layer.0.SelfAttention.q.weight\n",
      "model.decoder.block.11.layer.0.SelfAttention.k.weight model.decoder.block.11.layer.0.SelfAttention.k.weight\n",
      "model.decoder.block.11.layer.0.SelfAttention.v.weight model.decoder.block.11.layer.0.SelfAttention.v.weight\n",
      "model.decoder.block.11.layer.0.SelfAttention.o.weight model.decoder.block.11.layer.0.SelfAttention.o.weight\n",
      "model.decoder.block.11.layer.0.layer_norm.weight model.decoder.block.11.layer.0.layer_norm.weight\n",
      "model.decoder.block.11.layer.1.EncDecAttention.q.weight model.decoder.block.11.layer.1.EncDecAttention.q.weight\n",
      "model.decoder.block.11.layer.1.EncDecAttention.k.weight model.decoder.block.11.layer.1.EncDecAttention.k.weight\n",
      "model.decoder.block.11.layer.1.EncDecAttention.v.weight model.decoder.block.11.layer.1.EncDecAttention.v.weight\n",
      "model.decoder.block.11.layer.1.EncDecAttention.o.weight model.decoder.block.11.layer.1.EncDecAttention.o.weight\n",
      "model.decoder.block.11.layer.1.layer_norm.weight model.decoder.block.11.layer.1.layer_norm.weight\n",
      "model.decoder.block.11.layer.2.DenseReluDense.wi_0.weight model.decoder.block.11.layer.2.DenseReluDense.wi_0.weight\n",
      "model.decoder.block.11.layer.2.DenseReluDense.wi_1.weight model.decoder.block.11.layer.2.DenseReluDense.wi_1.weight\n",
      "model.decoder.block.11.layer.2.DenseReluDense.wo.weight model.decoder.block.11.layer.2.DenseReluDense.wo.weight\n",
      "model.decoder.block.11.layer.2.layer_norm.weight model.decoder.block.11.layer.2.layer_norm.weight\n",
      "model.decoder.final_layer_norm.weight model.decoder.final_layer_norm.weight\n",
      "model.lm_head.weight model.lm_head.weight\n"
     ]
    }
   ],
   "source": [
    "for k, v in old_weights:\n",
    "    new_k = k.replace('._orig_mod', '')\n",
    "    print(k, new_k)\n",
    "    weights[new_k] = v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "839ae22f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.load_state_dict(weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c698e532",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.copied_utils import (\n",
    "    compute_input_and_target_lengths,\n",
    "    DataCollatorForT5MLM,\n",
    "    tokenize_function,\n",
    "    DataCollatorForNI,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "60e223c7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(568, 114)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "before_mask_input_length, target_length = compute_input_and_target_lengths(\n",
    "    inputs_length=512,\n",
    "    noise_density=0.15,\n",
    "    mean_noise_span_length=3.0,\n",
    ")\n",
    "before_mask_input_length, target_length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e78b0fbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from streaming.base.format.mds.encodings import Encoding, _encodings\n",
    "from streaming import StreamingDataset\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "class Int32(Encoding):\n",
    "    def encode(self, obj) -> bytes:\n",
    "        return obj.tobytes()\n",
    "\n",
    "    def decode(self, data: bytes):\n",
    "        return np.frombuffer(data, np.int32)\n",
    "\n",
    "\n",
    "_encodings['int32'] = Int32\n",
    "\n",
    "\n",
    "class DatasetFixed(torch.utils.data.Dataset):\n",
    "    def __init__(self, local):\n",
    "        self.dataset = StreamingDataset(local=local)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        data = self.dataset[idx]\n",
    "        data.pop('token_type_ids', None)\n",
    "        for k in data.keys():\n",
    "            data[k] = data[k].astype(np.int64)\n",
    "        return data\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d3c2bef2",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = DatasetFixed(local='/home/ubuntu/nanot5-512')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "317b0cda",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_collator = DataCollatorForT5MLM(\n",
    "    tokenizer=tokenizer,\n",
    "    noise_density=0.15,\n",
    "    mean_noise_span_length=3.0,\n",
    "    input_length=512,\n",
    "    target_length=target_length,\n",
    "    pad_token_id=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "4028b237",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = [dataset[i] for i in range(5000000, 5000001, 1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d788fba7",
   "metadata": {},
   "outputs": [],
   "source": [
    "padded = data_collator(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "ad3a2d8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "o = model.model(**padded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "ba335b10",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-1.4642,  1.9711,  8.1804,  ..., -1.3669, -1.2231, -1.7239],\n",
       "         [-2.6246, -0.4478, 13.4775,  ..., -0.4373, -1.4065, -1.1230],\n",
       "         [-2.0779,  2.4580, 14.7889,  ..., -1.0686, -1.1595, -1.2585],\n",
       "         ...,\n",
       "         [-1.3984,  2.0701,  8.1561,  ..., -1.2381, -2.6082, -1.7596],\n",
       "         [-2.2733, -0.1507,  8.6637,  ..., -2.3320, -1.3610, -2.8000],\n",
       "         [-0.5552,  6.5587, 32.7664,  ..., -0.7088, -1.3464, -2.2160]]],\n",
       "       grad_fn=<UnsafeViewBackward0>)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "o.logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "e29488ad",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<extra_id_40> dll)</s><extra_id_59> Bawal Exclusive pi kilang S<extra_id_38> husband reseller<extra_id_18> tanya<extra_id_46> Bini<extra_id_66>..<extra_id_71> la sk<extra_id_67>t<extra_id_6> hubungan<extra_id_53>wah pers<extra_id_10> tiada jawapan ditemui, kerana yang<extra_id_61>anya. Bagaikan tersedar dari lena<extra_id_2> penulis mengakhiri<extra_id_84>ang dada. Mener<extra_id_8>.<extra_id_28> pasti<extra_id_70></s><extra_id_76><s> Ad<extra_id_62>2<extra_id_54> dah t<extra_id_12>osaur<extra_id_0> Best x??<extra_id_75> lama lagi<extra_id_85> duit kt die tuk<extra_id_83>agaknyalahhhh</s><s> Edited by<extra_id_52>antang<extra_id_89>....<extra_id_77>...ni ha...mkn la</s>'"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "padded['labels']\n",
    "tokenizer.decode(padded['labels'][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "3206b0ff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<extra_id_40></s></s></s><extra_id_59> Sarac S<extra_id_38> S S S<extra_id_38> kawan<extra_id_18>ign<extra_id_18> kata kat makinodoh dia.<extra_id_71> sk nak<extra_id_67>.<extra_id_6> hubungan<extra_id_53>wah pers<extra_id_10> penulis siapa yang. penulis si<extra_id_61>anya. Namunimbangkan tidakedar<extra_id_2> segalaiku<extra_id_2> penulis menerimahiri<extra_id_84>ang dada, Mener<extra_id_8>.<extra_id_28> pasti<extra_id_70></s><extra_id_76><s> Ad<extra_id_62>2<extra_id_54> t t<extra_id_12>our<extra_id_0></s> ke?<extra_id_75> lama<extra_id_85><extra_id_85> die utk die utk<extra_id_83></s>aknya</s></s></s></s><s> Edited by<extra_id_52>n<extra_id_89>.<extra_id_77> tu</s> mestiuk</s>cmn</s>'"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(o.logits.argmax(-1).detach().cpu().numpy()[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14cbcce3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
